import math
from math import ceil
from typing import Sequence, Union, List
import numpy as np
import torch

# Type alias for the sequence-like inputs (names, predictions, labels, selected models) accepted by the
# metric functions in this module: a NumPy array or any sequence (``List`` is itself a ``Sequence``).
MY_SEQ_TYPE = Union[np.ndarray, Sequence, List]


def __calc_optimistic_random_repetitions(sorted_preds, k, equality_loc):
    """
    Split ranked names into an always-selected head and a pool to sample from, and pick the number of draws
    :param sorted_preds: names sorted best-first
    :param k: total number of names selected in each random draw
    :param equality_loc: number of leading names that are always selected; the remaining names are the
                         candidates for random selection
    :return: (choose_from, num_repetitions, selected_top) - the candidate pool, the number of random draws to
             perform (never less than 1000), and the always-selected names
    """
    choose_from = sorted_preds[equality_loc:]
    selected_top = sorted_preds[:equality_loc]
    n_candidates = len(choose_from)
    # Each draw only fills the slots not already taken by ``selected_top``.
    n_slots = max(k - len(selected_top), 0)
    # math.comb stays in exact integers; the previous factorial quotient was a float division that raised
    # OverflowError for large candidate pools, and it counted size-k combinations instead of size-n_slots ones.
    n_choose_k = math.comb(n_candidates, n_slots)
    num_repetitions = int(max(min(n_choose_k, ceil(n_candidates / k) * 10), 1000))
    return choose_from, num_repetitions, selected_top


def _get_top_k_names(names: MY_SEQ_TYPE, values: MY_SEQ_TYPE, k: int, optimistic: bool):
    """
    Gets the top k names according to values, can return more than k results if optimistic
    :param names: names of data in order of values (converted to strings)
    :param values: values for selecting k best (higher is better)
    :param k: number of names to select; values above len(names) are clamped, values <= 0 select nothing
    :param optimistic: False for the first k after sorting (ties are broken by the name field), True will look at
                       the k_th value and will return all names with equal or above the k_th value
    :return: (top_names, equality_loc) - the selected names sorted best-first, and, when optimistic, the number
             of names strictly better than the k_th value. Those names always belong to the top k, and
             ``top_names[equality_loc:]`` all tie with the k_th value. equality_loc is -1 when not optimistic.
    :raises ValueError: if names and values have different lengths
    """
    if len(names) != len(values):
        # zip() would silently stop early and leave uninitialised np.empty rows in the table.
        raise ValueError(f"names and values must have the same length, got {len(names)} and {len(values)}")
    # Size the string field to the longest name so long names are not truncated (and possibly merged).
    name_width = max([256] + [len(str(name)) for name in names])
    names_values = np.empty(len(names), dtype=[('name', f'U{name_width}'), ('value', float)])
    for idx, data in enumerate(zip(names, values)):
        names_values[idx] = data
    # Ascending sort on value (remaining fields break ties), reversed to get best-first order.
    sort_ans = np.sort(names_values, order='value')[::-1]
    k = min(k, len(sort_ans))
    if k <= 0:
        # Nothing to select; avoid indexing [k-1] == [-1], which would pick the worst value.
        return sort_ans['name'][:0], (0 if optimistic else -1)
    if optimistic:
        k_val = sort_ans['value'][k - 1]
        # sort_ans is best-first, so the strictly-better names are exactly the leading ones. We need their
        # COUNT (the start of the tie group), not the index of the last of them.
        equality_loc = int(np.count_nonzero(sort_ans['value'] > k_val))
        top_names = sort_ans[sort_ans['value'] >= k_val]['name']
    else:
        top_names = sort_ans['name'][:k]
        equality_loc = -1

    return top_names, equality_loc


def precision_at_k_selected_group(names: MY_SEQ_TYPE, y_true: MY_SEQ_TYPE, selected_models: MY_SEQ_TYPE, k: int, random_optimistic: bool) -> float:
    """
    Precision at k for an externally chosen group of models instead of a ranking derived from predictions
    :param names: names of data, aligned with y_true
    :param y_true: label values used to build the true top-k set
    :param selected_models: names of the selected models, best first (typically k of them)
    :param k: cut-off rank
    :param random_optimistic: False takes the first k names of the true ranking; True counts every name whose
                label is equal to or above the k_th label as relevant
    :return: precision of selected_models against the true top-k set
    """
    # NOTE(review): the equality location comes from the *true* ranking but is applied to selected_models
    # when len(selected_models) > k; confirm this is intended.
    true_top, true_equality_loc = _get_top_k_names(names, y_true, k, random_optimistic)
    return _precision_at_k_actual_calc(sorted_true=true_top, sorted_preds=selected_models,
                                       equality_loc=true_equality_loc, k=k,
                                       random_optimistic=random_optimistic)


def precision_at_k(names: MY_SEQ_TYPE, y_preds: MY_SEQ_TYPE, y_true: MY_SEQ_TYPE, k: int, random_optimistic: bool) -> float:
    """
    Precision at k of the ranking induced by y_preds against the ranking induced by y_true
    :param names: names of data, aligned with both y_preds and y_true
    :param y_preds: predicted values (higher ranks better)
    :param y_true: label values (higher ranks better)
    :param k: cut-off rank
    :param random_optimistic: False takes the first k of each ranking; True keeps everything tied with the k_th
                value and averages over random picks among the tied predictions
    :return: precision at k
    """
    predicted_top, pred_equality_loc = _get_top_k_names(names, y_preds, k, random_optimistic)
    true_top = _get_top_k_names(names, y_true, k, random_optimistic)[0]
    return _precision_at_k_actual_calc(sorted_true=true_top, sorted_preds=predicted_top,
                                       equality_loc=pred_equality_loc, k=k,
                                       random_optimistic=random_optimistic)


def _precision_at_k_actual_calc(sorted_true, sorted_preds, equality_loc, k, random_optimistic):
    """
    Shared precision core: the number of predicted names found in the true set, divided by k
    :param sorted_true: names considered relevant
    :param sorted_preds: predicted names, best first
    :param equality_loc: number of leading predicted names that are always kept, -1 for the non-optimistic mode
    :param k: cut-off rank
    :param random_optimistic: when True and sorted_preds holds more than k names, average the hit count over
                              random draws that keep the leading names and sample the rest without replacement
    :return: hits / k, capped at 1.0
    """
    use_sampling = random_optimistic and equality_loc != -1 and len(sorted_preds) > k
    if use_sampling:
        pool, repetitions, always_kept = __calc_optimistic_random_repetitions(sorted_preds, k, equality_loc)
        n_to_draw = k - len(always_kept)
        hit_counts = []
        for _ in range(repetitions):
            drawn = np.random.choice(pool, size=n_to_draw, replace=False)
            candidate = np.concatenate((always_kept, drawn))
            hit_counts.append(np.sum(np.isin(candidate, sorted_true)))
        hits = np.mean(hit_counts)
    else:
        hits = np.sum(np.isin(sorted_preds, sorted_true))

    # In the optimistic case more than k names can be counted, so the ratio is capped at 1.
    return min(hits / k, 1.)


def average_precision_at_k_selected_group(names: MY_SEQ_TYPE, y_true: MY_SEQ_TYPE, selected_models: MY_SEQ_TYPE, k: int,
                                          random_optimistic: bool) -> float:
    """
    Average precision at k for an externally chosen group of models
    The precisions at cut-offs 1..k are summed and divided by min(k, len(y_true)).
    :param names: names of data, aligned with y_true
    :param y_true: label values
    :param selected_models: names of the selected models, best first
    :param k: largest cut-off rank
    :param random_optimistic: forwarded to precision_at_k_selected_group for every cut-off
    :return: average precision at k
    """
    per_cutoff = [
        precision_at_k_selected_group(names, y_true=y_true, selected_models=selected_models, k=cutoff,
                                      random_optimistic=random_optimistic)
        for cutoff in range(1, k + 1)
    ]
    normaliser = min(k, len(y_true))
    return np.sum(per_cutoff) / normaliser


def average_precision_at_k(names: MY_SEQ_TYPE, y_preds: MY_SEQ_TYPE, y_true: MY_SEQ_TYPE, k: int, random_optimistic: bool) -> float:
    """
    Average precision at k
    The precisions at cut-offs 1..k are summed and divided by min(k, len(y_true)).
    :param names: names of data, aligned with y_preds and y_true
    :param y_preds: predicted values
    :param y_true: label values
    :param k: largest cut-off rank
    :param random_optimistic: forwarded to precision_at_k for every cut-off
    :return: average precision at k
    """
    per_cutoff = [
        precision_at_k(names, y_preds=y_preds, y_true=y_true, k=cutoff, random_optimistic=random_optimistic)
        for cutoff in range(1, k + 1)
    ]
    normaliser = min(k, len(y_true))
    return np.sum(per_cutoff) / normaliser


def recall_at_k(names: MY_SEQ_TYPE, y_preds: MY_SEQ_TYPE, y_true: MY_SEQ_TYPE, k: int, th: float,
                random_optimistic: bool) -> float:
    """
    Calculates recall at k: the fraction of relevant names that appear among the top k predictions
    :param names: names of data in order of y_preds and y_true
    :param y_preds: predictions values
    :param y_true: labels values
    :param k: number of predicted names to select
    :param th: fraction of the data considered relevant; the relevant set is the top int(len(y_true) * th)
               names by y_true (taken non-optimistically)
    :param random_optimistic: False for the first k predictions; True averages over random picks among the
                predictions tied with the k_th predicted value
    :return: recall at k; NaN when th selects no relevant names
    """
    relevant_models, _ = _get_top_k_names(names, y_true, int(len(y_true) * th), False)
    selected_models, equality_k_loc = _get_top_k_names(names, y_preds, k, random_optimistic)
    if len(relevant_models) == 0:
        # Recall is undefined without relevant items; return NaN explicitly instead of dividing by zero.
        return float('nan')
    # Random sampling only matters when more than k names tie for the last slots. With exactly k names every
    # draw picks the same set, so the deterministic count is used instead of ~1000 identical draws.
    if random_optimistic and equality_k_loc != -1 and len(selected_models) > k:
        choose_from, num_repetitions, selected_top = __calc_optimistic_random_repetitions(
            sorted_preds=selected_models, k=k, equality_loc=equality_k_loc)
        n_to_draw = k - len(selected_top)
        random_k_relevant = [
            np.sum(np.isin(np.concatenate((selected_top, np.random.choice(choose_from, size=n_to_draw, replace=False))),
                           relevant_models))
            for _ in range(num_repetitions)]
        relevant_selected = np.mean(random_k_relevant)
    else:
        relevant_selected = np.sum(np.isin(selected_models, relevant_models))

    return relevant_selected / len(relevant_models)


def average_recall_at_k(names: MY_SEQ_TYPE, y_preds: MY_SEQ_TYPE, y_true: MY_SEQ_TYPE, k: int, th: float,
                        random_optimistic: bool) -> float:
    """
    Sum of recall at cut-offs 1..k, divided by the number of relevant names (int(len(y_true) * th))
    :param names: names of data in order of y_preds and y_true
    :param y_preds: predictions values
    :param y_true: labels values
    :param k: largest cut-off rank
    :param th: fraction of the data considered relevant (forwarded to recall_at_k)
    :param random_optimistic: forwarded to recall_at_k for every cut-off
    :return: averaged recall
    """
    recalls = [
        recall_at_k(names, y_preds=y_preds, y_true=y_true, k=cutoff, random_optimistic=random_optimistic, th=th)
        for cutoff in range(1, k + 1)
    ]
    n_relevant = int(len(y_true) * th)
    return np.sum(recalls) / n_relevant


def r2_loss(output: torch.Tensor, target: torch.Tensor):
    """
    Coefficient of determination (R^2) of ``output`` against ``target``, computed over all elements
    Despite the name, this returns R^2 itself (1 means a perfect fit), not a quantity to minimise.
    A constant ``target`` makes the denominator zero.
    :param output: predicted values
    :param target: ground-truth values
    :return: scalar tensor ``1 - SS_res / SS_tot``
    """
    residual_sum_sq = (target - output).pow(2).sum()
    centered_target = target - target.mean()
    total_sum_sq = centered_target.pow(2).sum()
    return 1 - residual_sum_sq / total_sum_sq


def regression_log_loss(preds: torch.Tensor, targets: torch.Tensor, reduce=torch.mean):
    """
    Negated reduction of the element-wise log absolute error between ``targets`` and ``preds``
    :param preds: predicted values
    :param targets: ground-truth values
    :param reduce: reduction applied to the element-wise log errors (``torch.mean`` by default)
    :return: ``-reduce(log|targets - preds|)``
    """
    # NOTE(review): an exact match gives log(0) = -inf (so +inf after negation), and lowering this value means
    # a larger |targets - preds|; confirm the sign and the missing epsilon are intended.
    log_abs_err = (targets - preds).abs().log()
    return -reduce(log_abs_err)
